{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4a36abc9-47b6-4e9a-8d2c-330e64012db2",
   "metadata": {},
   "source": [
    "# YOLOV8预训练模型预测-Python API-视频\n",
    "\n",
    "同济子豪兄：https://space.bilibili.com/1900783"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47341443-35bf-4a0c-b214-41a82af6ee0f",
   "metadata": {},
   "source": [
    "## 导入工具包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "abda64d2-c108-409b-9bc1-9c3b637decbb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cuda:0\n"
     ]
    }
   ],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "\n",
    "from ultralytics import YOLO\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import torch\n",
    "# 有 GPU 就用 GPU，没有就用 CPU\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "print('device:', device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d27ac8d-f8ff-41b5-88e2-054c8786f500",
   "metadata": {},
   "source": [
    "## 载入模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "42e5ecf0-bb31-4d5d-bf63-53e8207740fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = YOLO('checkpoint/Triangle_215_yolov8n_pose_pretrain.pt')\n",
    "model = YOLO('checkpoint/Triangle_215_yolov8x_p6_pretrain.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3af2e447-f9a0-4692-b88b-706abc07b15f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7a7ba207-3066-4080-9be3-b609acf37717",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 切换计算设备\n",
    "# model.cpu()  # CPU\n",
    "# model.cuda() # GPU"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb86a424-6df1-4893-ac43-ac4882597ef3",
   "metadata": {},
   "source": [
    "## 可视化配置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "83826d9e-7531-4538-8b85-00956593761d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 框（rectangle）可视化配置\n",
    "bbox_color = (150, 0, 0)             # 框的 BGR 颜色\n",
    "bbox_thickness = 2                   # 框的线宽\n",
    "\n",
    "# 框类别文字\n",
    "bbox_labelstr = {\n",
    "    'font_size':1,         # 字体大小\n",
    "    'font_thickness':2,    # 字体粗细\n",
    "    'offset_x':0,          # X 方向，文字偏移距离，向右为正\n",
    "    'offset_y':-10,        # Y 方向，文字偏移距离，向下为正\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f4f56e97-cb45-408b-a4a2-a980b82cc05f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 关键点 BGR 配色\n",
    "kpt_color_map = {\n",
    "    0:{'name':'angle_30', 'color':[255, 0, 0], 'radius':6},      # 30度角点\n",
    "    1:{'name':'angle_60', 'color':[0, 255, 0], 'radius':6},      # 60度角点\n",
    "    2:{'name':'angle_90', 'color':[0, 0, 255], 'radius':6},      # 90度角点\n",
    "}\n",
    "\n",
    "# 点类别文字\n",
    "kpt_labelstr = {\n",
    "    'font_size':1,             # 字体大小\n",
    "    'font_thickness':3,       # 字体粗细\n",
    "    'offset_x':10,             # X 方向，文字偏移距离，向右为正\n",
    "    'offset_y':0,            # Y 方向，文字偏移距离，向下为正\n",
    "}\n",
    "\n",
    "# 骨架连接 BGR 配色\n",
    "skeleton_map = [\n",
    "    {'srt_kpt_id':0, 'dst_kpt_id':1, 'color':[196, 75, 255], 'thickness':2},        # 30度角点-60度角点\n",
    "    {'srt_kpt_id':0, 'dst_kpt_id':2, 'color':[180, 187, 28], 'thickness':2},        # 30度角点-90度角点\n",
    "    {'srt_kpt_id':1, 'dst_kpt_id':2, 'color':[47,255, 173], 'thickness':2},         # 60度角点-90度角点\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eb43857-3d47-4bae-97ac-a0e458cf82c2",
   "metadata": {},
   "source": [
    "## 逐帧处理函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "911460a8-dd8b-4468-9a25-b0c97b819b05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_frame(img_bgr):\n",
    "    \n",
    "    '''\n",
    "    输入摄像头画面 bgr-array，输出图像 bgr-array\n",
    "    '''\n",
    "    \n",
    "    results = model(img_bgr, verbose=False) # verbose设置为False，不单独打印每一帧预测结果\n",
    "    \n",
    "    # 预测框的个数\n",
    "    num_bbox = len(results[0].boxes.cls)\n",
    "    \n",
    "    # 预测框的 xyxy 坐标\n",
    "    bboxes_xyxy = results[0].boxes.xyxy.cpu().numpy().astype('uint32') \n",
    "    \n",
    "    # 关键点的 xy 坐标\n",
    "    bboxes_keypoints = results[0].keypoints.cpu().numpy().astype('uint32')\n",
    "    \n",
    "    for idx in range(num_bbox): # 遍历每个框\n",
    "\n",
    "        # 获取该框坐标\n",
    "        bbox_xyxy = bboxes_xyxy[idx] \n",
    "\n",
    "        # 获取框的预测类别（对于关键点检测，只有一个类别）\n",
    "        bbox_label = results[0].names[0]\n",
    "\n",
    "        # 画框\n",
    "        img_bgr = cv2.rectangle(img_bgr, (bbox_xyxy[0], bbox_xyxy[1]), (bbox_xyxy[2], bbox_xyxy[3]), bbox_color, bbox_thickness)\n",
    "\n",
    "        # 写框类别文字：图片，文字字符串，文字左上角坐标，字体，字体大小，颜色，字体粗细\n",
    "        img_bgr = cv2.putText(img_bgr, bbox_label, (bbox_xyxy[0]+bbox_labelstr['offset_x'], bbox_xyxy[1]+bbox_labelstr['offset_y']), cv2.FONT_HERSHEY_SIMPLEX, bbox_labelstr['font_size'], bbox_color, bbox_labelstr['font_thickness'])\n",
    "\n",
    "        bbox_keypoints = bboxes_keypoints[idx] # 该框所有关键点坐标和置信度\n",
    "\n",
    "        # 画该框的骨架连接\n",
    "        for skeleton in skeleton_map:\n",
    "\n",
    "            # 获取起始点坐标\n",
    "            srt_kpt_id = skeleton['srt_kpt_id']\n",
    "            srt_kpt_x = bbox_keypoints[srt_kpt_id][0]\n",
    "            srt_kpt_y = bbox_keypoints[srt_kpt_id][1]\n",
    "\n",
    "            # 获取终止点坐标\n",
    "            dst_kpt_id = skeleton['dst_kpt_id']\n",
    "            dst_kpt_x = bbox_keypoints[dst_kpt_id][0]\n",
    "            dst_kpt_y = bbox_keypoints[dst_kpt_id][1]\n",
    "\n",
    "            # 获取骨架连接颜色\n",
    "            skeleton_color = skeleton['color']\n",
    "\n",
    "            # 获取骨架连接线宽\n",
    "            skeleton_thickness = skeleton['thickness']\n",
    "\n",
    "            # 画骨架连接\n",
    "            img_bgr = cv2.line(img_bgr, (srt_kpt_x, srt_kpt_y),(dst_kpt_x, dst_kpt_y),color=skeleton_color,thickness=skeleton_thickness)  \n",
    "            \n",
    "        # 画该框的关键点\n",
    "        for kpt_id in kpt_color_map:\n",
    "\n",
    "            # 获取该关键点的颜色、半径、XY坐标\n",
    "            kpt_color = kpt_color_map[kpt_id]['color']\n",
    "            kpt_radius = kpt_color_map[kpt_id]['radius']\n",
    "            kpt_x = bbox_keypoints[kpt_id][0]\n",
    "            kpt_y = bbox_keypoints[kpt_id][1]\n",
    "\n",
    "            # 画圆：图片、XY坐标、半径、颜色、线宽（-1为填充）\n",
    "            img_bgr = cv2.circle(img_bgr, (kpt_x, kpt_y), kpt_radius, kpt_color, -1)\n",
    "\n",
    "            # 写关键点类别文字：图片，文字字符串，文字左上角坐标，字体，字体大小，颜色，字体粗细\n",
    "            kpt_label = str(kpt_id) # 写关键点类别 ID（二选一）\n",
    "            # kpt_label = str(kpt_color_map[kpt_id]['name']) # 写关键点类别名称（二选一）\n",
    "            img_bgr = cv2.putText(img_bgr, kpt_label, (kpt_x+kpt_labelstr['offset_x'], kpt_y+kpt_labelstr['offset_y']), cv2.FONT_HERSHEY_SIMPLEX, kpt_labelstr['font_size'], kpt_color, kpt_labelstr['font_thickness'])\n",
    "    \n",
    "    return img_bgr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f169a0c0-35d1-40e2-a11d-5e7ec6c8a9aa",
   "metadata": {},
   "source": [
    "## 视频逐帧处理（模板）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e36d9bd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 视频逐帧处理代码模板\n",
    "# 不需修改任何代码，只需定义process_frame函数即可\n",
    "# 同济子豪兄 2021-7-10\n",
    "\n",
    "def generate_video(input_path='videos/robot.mp4'):\n",
    "    filehead = input_path.split('/')[-1]\n",
    "    output_path = \"out-\" + filehead\n",
    "    \n",
    "    print('视频开始处理',input_path)\n",
    "    \n",
    "    # 获取视频总帧数\n",
    "    cap = cv2.VideoCapture(input_path)\n",
    "    frame_count = 0\n",
    "    while(cap.isOpened()):\n",
    "        success, frame = cap.read()\n",
    "        frame_count += 1\n",
    "        if not success:\n",
    "            break\n",
    "    cap.release()\n",
    "    print('视频总帧数为',frame_count)\n",
    "    \n",
    "    # cv2.namedWindow('Crack Detection and Measurement Video Processing')\n",
    "    cap = cv2.VideoCapture(input_path)\n",
    "    frame_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
    "\n",
    "    # fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))\n",
    "    # fourcc = cv2.VideoWriter_fourcc(*'XVID')\n",
    "    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
    "    fps = cap.get(cv2.CAP_PROP_FPS)\n",
    "\n",
    "    out = cv2.VideoWriter(output_path, fourcc, fps, (int(frame_size[0]), int(frame_size[1])))\n",
    "    \n",
    "    # 进度条绑定视频总帧数\n",
    "    with tqdm(total=frame_count-1) as pbar:\n",
    "        try:\n",
    "            while(cap.isOpened()):\n",
    "                success, frame = cap.read()\n",
    "                if not success:\n",
    "                    break\n",
    "\n",
    "                # 处理帧\n",
    "                # frame_path = './temp_frame.png'\n",
    "                # cv2.imwrite(frame_path, frame)\n",
    "                try:\n",
    "                    frame = process_frame(frame)\n",
    "                except Exception as error:\n",
    "                    print('报错！', error)\n",
    "                    pass\n",
    "                \n",
    "                if success == True:\n",
    "                    # cv2.imshow('Video Processing', frame)\n",
    "                    out.write(frame)\n",
    "\n",
    "                    # 进度条更新一帧\n",
    "                    pbar.update(1)\n",
    "\n",
    "                # if cv2.waitKey(1) & 0xFF == ord('q'):\n",
    "                    # break\n",
    "        except:\n",
    "            print('中途中断')\n",
    "            pass\n",
    "\n",
    "    cv2.destroyAllWindows()\n",
    "    out.release()\n",
    "    cap.release()\n",
    "    print('视频已保存', output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eccb873a-896b-4331-8dd1-c15ec3622f4f",
   "metadata": {},
   "source": [
    "## 视频预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "75a7c2bb-a2c3-4026-b858-be8b18cd0b81",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "视频开始处理 videos/triangle_7.mp4\n",
      "视频总帧数为 220\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 219/219 [00:16<00:00, 13.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "视频已保存 out-triangle_7.mp4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "generate_video(input_path='videos/triangle_7.mp4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e539a260-3e9a-4fc8-9091-aa8af66332d7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
